import numpy as np
import pickle

print("load data")
file = 'vqa.data'
print(file)
# NOTE(review): pickle.load can execute arbitrary code while unpickling;
# only point this at a trusted file.
with open(file, "rb") as f:
    vqa_data = pickle.load(f)

# Every record must unpack into exactly three items: visual features, text
# features and the label. Materialize the records once, then pull out each
# column and stack it into a numpy array.
records = list(vqa_data)
vs = np.array([visual for visual, _, _ in records])
ts = np.array([text for _, text, _ in records])
label = np.array([target for _, _, target in records])
print("load data success!")

import torch
import torch.nn as nn

class Model(nn.Module):
    """Fusion classifier over concatenated visual and text features.

    The two feature tensors are concatenated along dim 1 and passed through
    a two-layer MLP (Linear -> ReLU -> Linear) that emits one logit per class.

    Args:
        in_dim: Width of the concatenated input, i.e. ``v.shape[1] + t.shape[1]``.
            The default of 300 matches this script's 200-wide visual features
            plus 100-wide text features.
        hidden_dim: Width of the hidden layer.
        num_classes: Number of output logits.
    """

    def __init__(self, in_dim=300, hidden_dim=200, num_classes=2):
        super(Model, self).__init__()
        self.backbone = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_classes),
        )

    def forward(self, v, t):
        """Return class logits of shape ``(batch, num_classes)``.

        Args:
            v: Visual features, at least 2-D with the batch on dim 0.
            t: Text features with the same batch size. Their dim-1 widths must
                sum to ``in_dim``.
        """
        x = torch.cat([v, t], dim=1)
        return self.backbone(x)

print("prepera model and optimizer")
best_mm = 0
model = Model()
optimizer = torch.optim.SGD(model.parameters(), lr=0.2, weight_decay=1e-4)

# CrossEntropyLoss expects int64 class indices, so cast the labels explicitly
# rather than relying on numpy's default integer dtype (int32 on some platforms).
vs = torch.from_numpy(vs).float()
ts = torch.from_numpy(ts).float()
label = torch.from_numpy(label).long()

# The first 4000 samples form the training split; the remainder is the test split.
train_vs = vs[:4000]
test_vs = vs[4000:]
train_ts = ts[:4000]
test_ts = ts[4000:]
train_label = label[:4000]
test_label = label[4000:]
print(train_vs.shape, train_ts.shape, train_label.shape)

# Normalize accuracy by the actual split sizes. Hard-coded 4000/1000 is only
# correct when the dataset holds exactly 5000 samples.
n_train = len(train_label)
n_test = len(test_label)
criterion = torch.nn.CrossEntropyLoss()

for i in range(1000):
    output = model(train_vs, train_ts)
    loss = criterion(output, train_label)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if i % 100 == 0:
        # Train accuracy uses the logits computed before this step's update.
        _, pred = output.max(1)
        train_correct = (pred == train_label).sum().item()
        print("train acc:", train_correct / n_train)
        with torch.no_grad():
            test_output = model(test_vs, test_ts)
            _, pred = test_output.max(1)
            eval_correct = (pred == test_label).sum().item()
            # Guard against an empty test split (dataset with <= 4000 samples).
            test_acc = eval_correct / n_test if n_test else 0.0
            print("test acc:", test_acc)
            best_mm = max(best_mm, test_acc)





class VModel(nn.Module):
    """Visual-only baseline: a two-layer MLP from 200 input features to 2 logits.

    The constructed layer stack is printed when the model is created.
    """

    def __init__(self):
        super().__init__()
        layers = [
            nn.Linear(200, 100),
            nn.ReLU(),
            nn.Linear(100, 2),
        ]
        self.backbone = nn.Sequential(*layers)
        print(self.backbone)

    def forward(self, v):
        """Map features whose last dimension is 200 to logits whose last dimension is 2."""
        return self.backbone(v)


model = VModel()
optimizer = torch.optim.SGD(model.parameters(), lr=0.2, weight_decay=1e-4)
best_v = 0

# Normalize accuracy by the actual split sizes instead of hard-coded 4000/1000.
n_train = len(train_label)
n_test = len(test_label)
criterion = torch.nn.CrossEntropyLoss()

for i in range(1000):
    output = model(train_vs)
    loss = criterion(output, train_label)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if i % 100 == 0:
        # Train accuracy uses the logits computed before this step's update.
        _, pred = output.max(1)
        train_correct = (pred == train_label).sum().item()
        print("v train acc:", train_correct / n_train)
        with torch.no_grad():
            test_output = model(test_vs)
            _, pred = test_output.max(1)
            eval_correct = (pred == test_label).sum().item()
            # Guard against an empty test split.
            test_acc = eval_correct / n_test if n_test else 0.0
            print("v test acc:", test_acc)
            best_v = max(best_v, test_acc)

class TModel(nn.Module):
    """Text-only baseline: a two-layer MLP from 100 input features to 2 logits.

    The constructed layer stack is printed when the model is created.
    """

    def __init__(self):
        super().__init__()
        hidden = 100
        self.backbone = nn.Sequential(
            nn.Linear(100, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 2),
        )
        print(self.backbone)

    def forward(self, v):
        """Map features whose last dimension is 100 to logits whose last dimension is 2.

        The argument is named ``v`` for historical reasons; this script passes
        the text features to it.
        """
        logits = self.backbone(v)
        return logits


model = TModel()
optimizer = torch.optim.SGD(model.parameters(), lr=0.2, weight_decay=1e-4)
best_t = 0

# Normalize accuracy by the actual split sizes instead of hard-coded 4000/1000.
n_train = len(train_label)
n_test = len(test_label)
criterion = torch.nn.CrossEntropyLoss()

for i in range(1000):
    output = model(train_ts)
    loss = criterion(output, train_label)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if i % 100 == 0:
        # Train accuracy uses the logits computed before this step's update.
        _, pred = output.max(1)
        train_correct = (pred == train_label).sum().item()
        print("t train acc:", train_correct / n_train)
        with torch.no_grad():
            test_output = model(test_ts)
            _, pred = test_output.max(1)
            eval_correct = (pred == test_label).sum().item()
            # Guard against an empty test split.
            test_acc = eval_correct / n_test if n_test else 0.0
            print("t test acc:", test_acc)
            best_t = max(best_t, test_acc)

print("file:{}, best_mm:{}, best_v:{}, best_t:{}".format(file, best_mm, best_v, best_t))

